# Kernel: sgemm_tn_192x192

# Copyright 2014 Nervana Systems Inc. All rights reserved.
# Copyright 2015~2016 Xiuxia Zhang zhangxiuxia@ict.ac.cn
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#    http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Symbolic names for the constant-bank-0 slots read by this kernel.
#   addr_zero            offset used with STS/LDS as an all-zero source
#   gridDimA, gridDimB   c[0x0][0x14] and c[0x0][0x18]
#   param_*              consecutive 4-byte slots from c[0x0][0x140] to c[0x0][0x194]
# Names carrying a [0]/[1] suffix are adjacent slot pairs. The setup code combines
# the param_A, param_B and param_debug pairs with IADD.CC / IADD.X, i.e. it treats
# [0] as the low word and [1] as the high word of a 64-bit address.
# NOTE(review): param_Rand, param_C, param_ldc, param_alpha, param_beta and
# param_ldcz are not referenced in this file; presumably they are consumed by the
# included common code - confirm.
<CONSTANT_MAPPING>
    addr_zero  : 4x<256*4*4>

    gridDimA : c[0x0][0x14]
    gridDimB : c[0x0][0x18]

    param_Rand[0]   : c[0x0][0x140]
    param_Rand[1]   : c[0x0][0x144]
    param_A[0]      : c[0x0][0x148]
    param_A[1]      : c[0x0][0x14c]
    param_B[0]      : c[0x0][0x150]
    param_B[1]      : c[0x0][0x154]
    param_C[0]      : c[0x0][0x158]
    param_C[1]      : c[0x0][0x15c]
    param_lda8      : c[0x0][0x160]
    param_ldb8      : c[0x0][0x164]
    param_ldc       : c[0x0][0x168]
    param_m         : c[0x0][0x16c]
    param_n         : c[0x0][0x170]
    param_k         : c[0x0][0x174]
    param_alpha     : c[0x0][0x178]
    param_beta      : c[0x0][0x17c]
    param_flags     : c[0x0][0x180]
    param_ldaz      : c[0x0][0x184]
    param_ldbz      : c[0x0][0x188]
    param_ldcz      : c[0x0][0x18c]
    param_debug[0]     : c[0x0][0x190]
    param_debug[1]     : c[0x0][0x194]
</CONSTANT_MAPPING>

<REGISTER_MAPPING>

    # Setup-phase temporaries drawn from registers 144-191. The same range is
    # mapped again further down (x/y names, j0/j1 operand names, and the
    # ldc/c/C00y group), so all of those names share physical registers with
    # these temporaries.
    144-191   ~ blkA, blkB, blkZ, lda, ldb, ldaz, ldbz, tid1, tid7, tidX, blk, tid63, tid128, dimA, flag, tbid, param_A_base_lo, param_A_base_hi, ta_shl, param_B_base_lo, param_B_base_hi,tb_shl, txa_shl

    # Registers 0-143 as one flat vector. The kernel entry code fills them with
    # zeros, four registers at a time, through czero00, czero04, ... czero140.
    0-143    : czero<00-143>

    # The same 144 registers named as a 12 x 12 grid: cx0..cx11 by y0..y11.
    # NOTE(review): these names are not referenced in this file; presumably they
    # are the accumulators updated by the included common code - confirm.
      1,  4, 25, 28, 49, 52, 73, 76, 97,100,121,124 : cx<0-11>y0
      5,  0, 29, 24, 53, 48, 77, 72,101, 96,125,120 : cx<0-11>y1
      3,  6, 27, 30, 51, 54, 75, 78, 99,102,123,126 : cx<0-11>y2
      7,  2, 31, 26, 55, 50, 79, 74,103, 98,127,122 : cx<0-11>y3
      9, 12, 33, 36, 57, 60, 81, 84,105,108,129,132 : cx<0-11>y4
     13,  8, 37, 32, 61, 56, 85, 80,109,104,133,128 : cx<0-11>y5
     11, 14, 35, 38, 59, 62, 83, 86,107,110,131,134 : cx<0-11>y6
     15, 10, 39, 34, 63, 58, 87, 82,111,106,135,130 : cx<0-11>y7
     17, 20, 41, 44, 65, 68, 89, 92,113,116,137,140 : cx<0-11>y8
     21, 16, 45, 40, 69, 64, 93, 88,117,112,141,136 : cx<0-11>y9
     19, 22, 43, 46, 67, 70, 91, 94,115,118,139,142 : cx<0-11>y10
     23, 18, 47, 42, 71, 66, 95, 90,119,114,143,138 : cx<0-11>y11

    #xsb
    144-191   ~ x<1-3>, y<1-3>

    # Two sets (j0 and j1) of twelve Ay and twelve Bx names on fixed registers
    # inside 144-191.
    # NOTE(review): not referenced in this file; presumably the operand
    # fragments read from shared memory by the included loop body - confirm.
    144-147   : j0Ay<0-3>
    152-155   : j0Ay<4-7>
    160-163   : j0Ay<8-11>
    148-151   : j0Bx<0-3>
    156-159   : j0Bx<4-7>
    164-167   : j0Bx<8-11>
    168-171   : j1Ay<0-3>
    176-179   : j1Ay<4-7>
    184-187   : j1Ay<8-11>
    172-175   : j1Bx<0-3>
    180-183   : j1Bx<4-7>
    188-191   : j1Bx<8-11>

    # Destinations of the 128-bit loads of A and B. The code stores loadA0,
    # loadA2, loadB0 and loadB2 to shared memory as 64-bit halves.
    192-197   : loadA<0-5>
    200-205   : loadB<0-5>

    # Register pairs holding the 64-bit addresses that step through A and B.
    198-199   : trackA<0-1>
    234-235   : trackB<0-1>

    # Loop state and per-thread offsets computed during setup.
    208-221 ~ lda8, k, k7, tidY, txa, txb, ta, tb, loop
    228     ~ writeS
    222     ~ readAs
    223     ~ readBs
    224     ~ tid
    225     ~ seed
    226-227 : Rand<0-1>

    # Names that reuse registers 144-203, overlapping the setup temporaries,
    # the j0/j1 operands, loadA/loadB and trackA above.
    # NOTE(review): apart from ldc-related setup none of these are referenced in
    # this file; judging by the names (ldc, alpha, beta, writeCs, readCs) they
    # presumably belong to the C write-out in the included common code - confirm.
    144-155   ~ ldc, ci, xmad_c, clk_shf1, clk_shf2, tid_31, tid_96, tid_128, mantissa_bits, blockA, blockB, blockZ
    144-159   : c<0-11>, d3, d2, d1, d0
    160-167   : C00y<0-1>, C04y<0-1>, C08y<0-1>, C12y<0-1>
    168-203  ~ ldc1, ldc4, ldc60, ldcz, writeCs, readCs, cx<00|64|128>, cy<00|04|08|12>, alpha, beta, flags, exp_scale, trunc_mask, lfsr<0-2>, exp<0-3>, rand<0-3>, lfsr<0-2>_1, lfsr<0-2>_2
    # Register pair for a 64-bit pointer that the setup code derives from param_debug.
    240-241 : debug<0-1>

</REGISTER_MAPPING>

// Kernel entry: read the thread index and the three block indices.
// Note the axis assignment: blkA comes from CTAID.Y, blkB from CTAID.Z and
// blkZ from CTAID.X.
-:-:-:-:00 S2R tid, SR_TID.X;
-:-:-:-:00 S2R blkA, SR_CTAID.Y;
-:-:-:-:00 S2R blkB, SR_CTAID.Z;
-:-:-:-:00 S2R blkZ, SR_CTAID.X;

// k starts at param_k and is counted down by the loop; loop starts at zero.
-:-:-:-:00 MOV k, param_k;
-:-:-:-:00 MOV loop, RZ;
// Store 16 bytes of zeros at addr_zero, then read them back into every group of
// four czero registers. The generated text is 36 LDS.U.128 instructions with
// destinations czero00, czero04, ... czero140, which zeroes registers 0-143.
-:-:-:-:00 STS.128 [RZ + addr_zero], RZ;
<CODE>
join('', map sprintf("-:-:-:-:00 LDS.U.128 czero%02d, [RZ + addr_zero];\n", $_ * 4), 0..35);
</CODE>

// Split the thread index:
//   tid63 = tid & 63        (low six bits)
//   tidX  = tid63 * 4       (each thread handles four consecutive elements, matching the 128-bit loads)
//   tidY  = bits 6..7 of tid
-:-:-:-:00 LOP.AND tid63, tid, 63;
-:-:-:-:00 SHL tidX, tid63, 2;
-:-:-:-:00 BFE.U32 tidY, tid, 0x206; // 2 bits at position 6

// lda = param_lda8 / 16 and ldb = param_ldb8 / 16 (unsigned shifts).
// NOTE(review): param_lda8 / param_ldb8 are also added directly to the byte
// addresses once per 4-step iteration, so they look like "leading dimension *
// 4 rows * 4 bytes" - confirm against the host-side launcher.
// ldaz and ldbz are loaded here but not used anywhere in this file.
-:-:-:-:00 MOV lda, param_lda8;
-:-:-:-:00 MOV ldb, param_ldb8;
-:-:-:-:00 SHR.U32 lda, lda, 4;
-:-:-:-:00 SHR.U32 ldb, ldb, 4;
-:-:-:-:00 MOV ldaz, param_ldaz;
-:-:-:-:00 MOV ldbz, param_ldbz;

// txa = blkA*128 + blkA*64 + tidX = blkA*192 + tidX
// ta  = lda*tidY + txa                      (element offset into A)
// txa_shl = ((tidY * 256) + txa) * 4        (byte offset, consumed only by the debug pointer setup)
-:-:-:-:00 SHL    txa, blkA, 7;
-:-:-:-:00 ISCADD txa, blkA, txa, 6;
-:-:-:-:00 IADD   txa, txa, tidX;
-:-:-:-:00 IMAD ta, lda, tidY, txa;
-:-:-:-:00 ISCADD txa_shl, tidY, txa, 8;
-:-:-:-:00 SHL txa_shl, txa_shl, 2;


// trackA = param_A + ta*4 as a 64-bit add: low words with carry-out, then the
// high word plus carry. The byte offset ta*4 itself is formed in 32 bits.
// P5 = (txa < param_m): this thread's column of A is inside the matrix.
-:-:-:-:00 MOV param_A_base_lo, param_A[0];
-:-:-:-:00 MOV param_A_base_hi, param_A[1];
-:-:-:-:00 SHL ta_shl, ta, 2;
-:-:-:-:00 IADD trackA0.CC, ta_shl, param_A_base_lo;
-:-:-:-:00 IADD.X trackA1, RZ, param_A_base_hi;
-:-:-:-:00 ISETP.LT.AND P5, PT, txa, param_m, PT;
// Same computation for B: txb = blkB*192 + tidX.
-:-:-:-:00 SHL    txb, blkB, 7;
-:-:-:-:00 ISCADD txb, blkB, txb, 6;
-:-:-:-:00 IADD txb, txb, tidX;

// tb = ldb*tidY + txb; trackB = param_B + tb*4; P6 = (txb < param_n).
-:-:-:-:00 IMAD tb, ldb, tidY, txb;
-:-:-:-:00 MOV param_B_base_lo, param_B[0];
-:-:-:-:00 MOV param_B_base_hi, param_B[1];
-:-:-:-:00 SHL tb_shl, tb, 2;
-:-:-:-:00 IADD trackB0.CC, tb_shl, param_B_base_lo;
-:-:-:-:00 IADD.X trackB1, RZ, param_B_base_hi;
-:-:-:-:00 ISETP.LT.AND P6, PT, txb, param_n, PT;
// Shared-memory store offset for this thread:
//   tid1   = tidX >> 1
//   writeS = ((tidY * 256) + tid1) * 4, then XORed with the buffer-toggle
//            constant (the same constant is XORed into writeS, readAs and
//            readBs again later to flip between buffers).
-:-:-:-:00 SHR tid1, tidX, 1;
-:-:-:-:00 ISCADD writeS, tidY, tid1, 8;
-:-:-:-:00 SHL writeS, writeS, 2;
-:-:-:-:00 LOP.XOR writeS, writeS, 4x<256*4*2>;

// Shared-memory read offset for the A operand. tid1 is reused here with a new
// meaning (bit 0 of tid):
//   readAs = (((tid & 0x70) >> 3) | (tid & 1)) * 8
-:-:-:-:00 LOP.AND tid1, tid, 1;
-:-:-:-:00 LOP.AND readAs, tid, 0x70;
-:-:-:-:00 SHR.U32 readAs, readAs, 3;
-:-:-:-:00 LOP.OR readAs, readAs, tid1;
-:-:-:-:00 SHL readAs, readAs, 3;


// Shared-memory read offset for the B operand:
//   readBs = ((((tid & 128) >> 4) | bits 1..3 of tid) * 8) + the B-region base constant
// The base constant matches the extra offset used by the STS of loadB
// relative to the STS of loadA.
-:-:-:-:00 LOP.AND tid128, tid, 128;
-:-:-:-:00 BFE.U32 tid7, tid, 0x301; // 3 bits at position 1
-:-:-:-:00 SHR.U32 readBs, tid128, 4;
-:-:-:-:00 LOP.OR readBs, readBs, tid7;
-:-:-:-:00 ISCADD readBs, readBs, 4x<256*4>, 3;
// flag is loaded from param_flags but not used anywhere in this file.
// Per-thread seed:
//   tbid = (blkA * 256) + tid
//   dimA = blkB * gridDimA
//   tbid = (dimA * 256) + tbid
//   seed = tbid masked to the low bits selected by (2048*32 - 1)
// NOTE(review): seed is not consumed in this file; presumably it feeds the
// random-number state used by the included common code - confirm.
-:-:-:-:00 MOV flag, param_flags;
-:-:-:-:00 MOV dimA, gridDimA;
-:-:-:-:00 ISCADD tbid, blkA, tid, 8;
-:-:-:-:00 IMAD dimA, blkB, dimA, RZ;
-:-:-:-:00 ISCADD tbid, dimA, tbid, 8;
-:-:-:-:00 LOP.AND seed, tbid, 1x<2048*32 - 1>;

// Debug pointer: debug = param_debug + txa_shl as a 64-bit add.
// param_A_base_lo / param_A_base_hi are reused as scratch for the two words of
// param_debug; trackA has already been formed from the A base at this point.
// NOTE(review): the two MOVs that zero debug0/debug1 are redundant, because
// the IADD / IADD.X pair below overwrites both registers.
-:-:-:-:00 MOV param_A_base_lo, param_debug[0];
-:-:-:-:00 MOV param_A_base_hi, param_debug[1];
-:-:-:-:00 MOV debug0, RZ;
-:-:-:-:00 MOV debug1, RZ;
-:-:-:-:00 IADD debug0.CC, txa_shl, param_A_base_lo;
-:-:-:-:00 IADD.X debug1, param_A_base_hi, RZ;

// REMAINDER: loads the first tile of A and B, stores it to shared memory and
// issues the prefetch for the following tile. Execution falls into this label
// after setup; the loop-exit branch table also jumps back here when P1 is set.
REMAINDER:


<CODE>
    # The returned text sets up the predicates and issues the first loads:
    #   k7 = k & 3                  (note: the mask is 3 although the name says 7)
    #   P4 = (k7 == 0)
    #   P1 = (k > 4) && !P4         (tested by the branch back to REMAINDER)
    #   P2 = (tidY < k) && P5       load A from global memory
    #   P3 = (tidY < k) && P6       load B from global memory
    # Threads whose P2 / P3 is false read zeros from addr_zero instead, so
    # loadA / loadB always hold defined data.
    return q{
-:-:-:-:00 LOP.AND k7, k, 3;
-:-:-:-:00 ISETP.EQ.AND P4, PT, k7, RZ, PT;
-:-:-:-:00 ISETP.GT.AND P1, PT, k, 4, !P4;

-:-:-:-:00 ISETP.LT.AND P2, PT, tidY, k, P5;
-:-:-:-:00 ISETP.LT.AND P3, PT, tidY, k, P6;
-:-:-:-:00 @P2 LDG.E.128 loadA, [trackA];
-:-:-:-:00 @P3 LDG.E.128 loadB, [trackB];

-:-:-:-:00 @!P2 LDS.U.128 loadA, [RZ + addr_zero];
-:-:-:-:00 @!P3 LDS.U.128 loadB, [RZ + addr_zero];
    };
</CODE>

// NOTE(review): TEXDEPBAR plus the NOP padding presumably waits for the LDG
// results above to land before the STS instructions read loadA / loadB - confirm.
T:-:D:S:00 TEXDEPBAR 0x0;
-:-:-:-:00 NOP;
-:-:-:-:00 NOP;
-:-:-:-:00 NOP;
-:-:-:-:00 NOP;
-:-:-:-:00 NOP;
-:-:-:-:00 NOP;

// Store the loaded data to shared memory as two 64-bit halves each, then
// advance the global pointers by param_lda8 / param_ldb8 bytes (64-bit add).
// NOTE(review): the stores are predicated on P5 / P6 (column inside the
// matrix), not on P2 / P3, so a thread with P5 or P6 false writes nothing, not
// even the zeros it loaded - confirm that its shared-memory slots are never read.
-:-:-:-:00 @P5 STS.64 [writeS + 4x<0*256>], loadA0;
-:-:-:-:00 @P5 STS.64 [writeS + 4x<0*256+128>], loadA2;
-:-:-:-:00 IADD trackA0.CC, trackA0, param_lda8;
-:-:-:-:00 IADD.X trackA1, trackA1, RZ;
-:-:-:-:00 @P6 STS.64 [writeS + 4x<4*256>], loadB0;
-:-:-:-:00 @P6 STS.64 [writeS + 4x<4*256+128>], loadB2;
-:-:-:-:00 IADD trackB0.CC, trackB0, param_ldb8;
-:-:-:-:00 IADD.X trackB1, trackB1, RZ;
// Prefetch for the following tile: P2 = (k >= 8) && P5, P3 = (k >= 8) && P6.
// As above, threads that do not load from global memory read zeros instead.
-:-:D:S:00 ISETP.GE.AND P2, PT, k, 8, P5;
-:-:D:S:00 ISETP.GE.AND P3, PT, k, 8, P6;
-:-:D:-:00 @P2 LDG.E.128 loadA, [trackA];
-:-:D:-:00 @P3 LDG.E.128 loadB, [trackB];
-:-:-:-:00 @!P2 LDS.U.128 loadA, [RZ + addr_zero];
-:-:-:-:00 @!P3 LDS.U.128 loadB, [RZ + addr_zero];
// Flip the read offsets to the buffer just written, synchronize the block so
// the stores are visible, then flip the write offset for the next store.
-:-:-:-:00 LOP.XOR readAs, readAs, 4x<256*4*2>;
-:-:-:-:00 LOP.XOR readBs, readBs, 4x<256*4*2>;
-:-:D:-:00 BAR.SYNC 0;
-:-:-:-:00 LOP.XOR writeS, writeS, 4x<256*4*2>;
<CODE>
    # Builds %insert, a table that maps position keys of the form jNcM to the
    # instruction text to place at that position. It is declared with "our" so
    # that it remains visible after this block returns.
    # NOTE(review): presumably the included sgemm_common_192x192.sass splices
    # each entry into the main loop body at step N, slot M, and defines the LOOP
    # label that the branch entry targets - confirm against that file.
    # $k_end is the threshold compared against k to decide whether another full
    # iteration follows.
    my $k_end = 8;
    our %insert =
    (
#----------------------------------------------------------------------
### k is the amount of work still remaining before the current iteration.
### Each iteration consumes 4 (see the "k-4" entry below). If k < 8 before the
### iteration, fewer than 4 remain after it, so the iteration must not issue
### the LDG of A/B that would prepare data for a following iteration.
### P2 = (k >= 8) && (txa < m) for LDG A
### P3 = (k >= 8) && (txb < n) for LDG B
### P0 = (k >= 8) gates the STS, the barrier, the buffer swaps and the loop branch
### NOTE(review): if entries are placed in j/c order, the j0 compare sees k before
### the j1 decrement while the j3 compares see it after - confirm the thresholds.
j3c65 => "T:-:D:S:00 ISETP.GE.AND P2, PT, k, $k_end, P5;\n",
j3c89 => "T:-:D:S:00 ISETP.GE.AND P3, PT, k, $k_end, P6;\n",
j0c5  => "T:-:D:S:00 ISETP.GE.AND P0, PT, k, $k_end, PT;\n",
#----------------------------------------------------------------------
### load A/B from global memory for a following iteration
j3c77  => "T:-:D:S:00 \@P2 LDG.E.128 loadA, [trackA];\n",
j3c101 => "T:-:D:S:00 \@P3 LDG.E.128 loadB, [trackB];\n" ,
#----------------------------------------------------------------------
### advance the global pointers (64-bit add of param_lda8 / param_ldb8)
### A:next col
### B:next row
j1c5   => "T:-:D:S:00 \@P2 IADD trackA0.CC, trackA0, param_lda8;\n",
j1c35  => "T:-:D:S:00 \@P2 IADD.X trackA1, trackA1, RZ;\n",
j1c77  => "T:-:D:S:00 \@P3 IADD trackB0.CC, trackB0, param_ldb8;\n",
j1c143 => "T:-:D:S:00 \@P3 IADD.X trackB1, trackB1, RZ;\n",
#----------------------------------------------------------------------
### STS A/B: store the previously loaded data to shared memory.
### These stores are predicated on P0 only, not on P5/P6.
j2c35  => "T:-:D:S:00 \@P0 TEXDEPBAR 0x0;\n",
j2c65  => "-:-:D:S:00 \@P0 STS.64 [writeS + 4x<0*256>], loadA0;\n",
j2c71  => "-:-:D:S:00 \@P0 STS.64 [writeS + 4x<0*256+128>], loadA2;\n",
j2c101 => "-:-:D:S:00 \@P0 STS.64 [writeS + 4x<4*256>], loadB0;\n",
j2c107 => "-:-:D:S:00 \@P0 STS.64 [writeS + 4x<4*256+128>], loadB2;\n",
#----------------------------------------------------------------------
### sync for STS A/B
j2c143 => "T:-:D:S:00 \@P0 BAR.SYNC 0;\n" ,
#----------------------------------------------------------------------
### swap readAs/readBs/writeS between the two shared-memory buffers
j2c125 => "T:-:D:S:00 \@P0 LOP.XOR readAs, readAs, 4x<256*4*2>;\n" ,
j3c5   => "T:-:D:S:00 \@P0 LOP.XOR readBs, readBs, 4x<256*4*2>;\n" ,
j3c107 => "T:-:D:S:00 \@P0 LOP.XOR writeS, writeS, 4x<256*4*2>;\n" ,
#----------------------------------------------------------------------
### k-4: each iteration consumes 4 from k
j1c71  => "-:-:-:-:00 IADD32I k, k, -4;\n",
#----------------------------------------------------------------------
### branch: continue the main loop while P0 holds; otherwise go back to
### REMAINDER when P1 holds. Two instructions share this one slot.
j3c143 => "T:-:D:S:00 \@P0 BRA.U LOOP;\n" .
"-:-:D:-:00 \@P1 BRA.U REMAINDER;\n",
#----------------------------------------------------------------------
    );
    # A bare return yields nothing, so this block only defines %insert.
    return;
</CODE>

<INCLUDE file="sgemm_common_192x192.sass"/>
